{
"cells": [
{
"cell_type": "markdown",
"id": "6fc60e51",
"metadata": {},
"source": [
"# GRU Decoder with Masked Self-Attention\n",
"\n",
"This notebook introduces **masked self-attention inside the decoder**, applied before the GRU update. The self-attention-augmented\n",
"encoder from the previous step is reused unchanged. Through causal masking, the decoder can attend to the tokens it has already\n",
"generated, while the recurrent state still carries temporal continuity from step to step.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36f1ab44",
"metadata": {},
"outputs": [],
"source": [
"import importlib\n",
"import os\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"import hw7\n",
"from tsv_seq2seq_data import TSVSeq2SeqData\n",
"\n",
"# Reload so edits to hw7.py are picked up without restarting the kernel.\n",
"importlib.reload(hw7)\n",
"# Import only the model classes this notebook uses (instead of `from hw7 import *`)\n",
"# so it is clear where every name comes from.\n",
"from hw7 import SelfAttentionAugmentedEncoder, SelfAttentiveGRUDecoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "509be716",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Dataset location; set the SEQ2SEQ_DATA environment variable to use a different file.\n",
"data_path = os.path.expanduser(\n",
"    os.environ.get('SEQ2SEQ_DATA', '~/Dropbox/CS6140/data/sentence_pairs_large.tsv'))\n",
"\n",
"# Seed torch's global RNG (weight init, dropout, ...) so training runs are repeatable.\n",
"torch.manual_seed(0)\n",
"\n",
"data = TSVSeq2SeqData(\n",
"    path=data_path,\n",
"    batch_size=512,\n",
"    num_steps=25,\n",
"    min_freq=2,\n",
"    val_frac=0.05,\n",
"    test_frac=0.0,\n",
"    sample_percent=1,\n",
")\n",
"\n",
"# Model hyperparameters (shared by encoder and decoder).\n",
"embed_size = 320\n",
"num_hiddens = 320\n",
"num_layers = 3\n",
"num_heads = 4\n",
"dropout = 0.35\n",
"\n",
"encoder = SelfAttentionAugmentedEncoder(\n",
"    len(data.src_vocab), embed_size, num_hiddens, num_layers,\n",
"    num_heads=num_heads, dropout=dropout)\n",
"decoder = SelfAttentiveGRUDecoder(\n",
"    len(data.tgt_vocab), embed_size, num_hiddens, num_layers,\n",
"    num_heads=num_heads, dropout=dropout)\n",
"# tgt_pad must be the padding-token index so the loss ignores padded positions;\n",
"# looking up a missing token would silently return the <unk> index instead.\n",
"model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.001)\n",
"trainer = d2l.Trainer(max_epochs=15, gradient_clip_val=1, num_gpus=1)\n",
"trainer.fit(model, data)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a768ece6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"vamos . => let's get free. | reference: go . | BLEU: 0.000\n",
"me perdi . => i will give me a ticket. | reference: i got lost . | BLEU: 0.000\n",
"esta tranquilo . => this | reference: he is calm . | BLEU: 0.000\n",
"estoy en casa . => i'm in the classroom. | reference: i am at home . | BLEU: 0.000\n",
"donde esta el tren ? => where is the train tomorrow? | reference: where is the train ? | BLEU: 0.832\n",
"necesito ayuda urgente . => i need to pay you free. | reference: i need urgent help . | BLEU: 0.386\n",
"ayer llovio mucho en la ciudad . => yesterday i was a lot of summer in the summer. | reference: it rained a lot in the city yesterday . | BLEU: 0.485\n",
"los ninos estan jugando en el parque . => the playing on the new room. | reference: the children are playing in the park . | BLEU: 0.000\n",
"ella quiere aprender a hablar ingles muy bien . => she wants to learn to learn a lot of spanish. | reference: she wants to learn to speak english very well . | BLEU: 0.577\n",
"cuando llegara el proximo tren a madrid ? => when he got to the he started to be | reference: when will the next train to madrid arrive ? | BLEU: 0.000\n"
]
}
],
"source": [
"# Spanish -> English sanity check: each pair is (source sentence, reference translation).\n",
"# Keeping them together prevents the two lists from drifting out of alignment.\n",
"pairs = [\n",
"    ('vamos .', 'go .'),\n",
"    ('me perdi .', 'i got lost .'),\n",
"    ('esta tranquilo .', 'he is calm .'),\n",
"    ('estoy en casa .', 'i am at home .'),\n",
"    ('donde esta el tren ?', 'where is the train ?'),\n",
"    ('necesito ayuda urgente .', 'i need urgent help .'),\n",
"    ('ayer llovio mucho en la ciudad .', 'it rained a lot in the city yesterday .'),\n",
"    ('los ninos estan jugando en el parque .', 'the children are playing in the park .'),\n",
"    ('ella quiere aprender a hablar ingles muy bien .', 'she wants to learn to speak english very well .'),\n",
"    ('cuando llegara el proximo tren a madrid ?', 'when will the next train to madrid arrive ?'),\n",
"]\n",
"examples = [src for src, _ in pairs]\n",
"references = [tgt for _, tgt in pairs]\n",
"\n",
"preds, _ = model.predict_step(\n",
"    data.build(examples, references), d2l.try_gpu(), data.num_steps)\n",
"for src, tgt, pred in zip(examples, references, preds):\n",
"    translation = []\n",
"    for token in data.tgt_vocab.to_tokens(pred):\n",
"        # Stop at end-of-sequence; anything generated after it is not part of the translation.\n",
"        if token == '<eos>':\n",
"            break\n",
"        translation.append(token)\n",
"    hypo = ' '.join(translation)\n",
"    print(f\"{src} => {hypo} | reference: {tgt} | BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18bd1a50",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python_mac_d2l",
"language": "python",
"name": "d2l"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.14.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}